import configparser
import time
from collections import defaultdict
from functools import partial
import numpy as np
import pandas as pd
from cplex_fair_assignment_lp_solver_large_cluster import fair_partial_assignment_large_cluster
from util.clusteringutil import (clean_data, read_data, scale_data,
                                 subsample_data, take_by_key,
                                 vanilla_clustering, write_fairness_trial)
from util.configutil import read_list
from util.probutil import deterministic_prob_vecs, form_class_prob_vector, sample_colors, create_prob_vecs, sample_colors_ml_model
from helpers_DS import get_center_colors_no_closing , get_center_colors_possible_closing
from range_DS_kcenter import fairKcenterRange 
from doubly_GF_functions import GF_to_GFDS 
from doubly_DS_functions import DS_to_GFDS 
from scipy.spatial.distance import cdist
from DS_algorithms import fairKcenter, fairKcenterPlusHeuristicB

def _assign_colors(df, variable, bucket_conditions):
    """Assign every row of ``df`` to a color class using the config's bucket conditions.

    Each entry of ``bucket_conditions`` is a Python expression string that is
    ``eval``-ed into a predicate and applied to ``row[variable]``.

    Returns:
        colors (defaultdict[int -> list[int]]): bucket index -> 0-based row
            positions whose value satisfied that bucket.
        color_flag (list[int]): row position -> bucket index. A row that
            satisfies several buckets is listed under each of them in
            ``colors`` but flagged with the last one; a row that matches no
            bucket keeps the default flag 0 without being listed in
            ``colors[0]``.
    """
    # NOTE(review): `eval` executes arbitrary code taken from the config file;
    # only run this with trusted config files.
    # Build each predicate once instead of re-evaluating it for every row.
    predicates = [eval(bucket) for bucket in bucket_conditions]

    colors = defaultdict(list)
    color_flag = [0] * len(df)
    # Use row *positions*, not index labels: `color_flag` is later aligned with
    # the rows of `df.to_numpy()`, so a non-default index (e.g. rows dropped
    # while cleaning) must not shift or overflow it.
    for pos, (_, row) in enumerate(df.iterrows()):
        value = row[variable]
        for bucket_idx, predicate in enumerate(predicates):
            if predicate(value):
                colors[bucket_idx].append(pos)
                color_flag[pos] = bucket_idx
    return colors, color_flag


def _nearest_center_assignment(X, center_indices, num_clusters):
    """Assign every row of ``X`` to its nearest center (Euclidean distance).

    Args:
        X (np.ndarray): points, one per row.
        center_indices: row indices of ``X`` that are the centers.
        num_clusters (int): number of columns of the assignment matrix.

    Returns:
        assignment (list[float]): the (num_points x num_clusters) 0/1 matrix,
            flattened row-major, with a single 1 per point in the column of
            its nearest center.
        cost (float): the largest distance from a point to its assigned
            center; 0 when there are no points.
    """
    distances = cdist(X, X[center_indices, :], 'euclidean')
    nearest = np.argmin(distances, axis=1)
    rows = np.arange(X.shape[0])

    x = np.zeros((X.shape[0], num_clusters))
    x[rows, nearest] = 1

    nearest_dist = distances[rows, nearest]
    # `max(0, ...)` reproduces a running maximum that starts from 0.
    cost = max(0, nearest_dist.max()) if nearest_dist.size else 0
    return x.ravel().tolist(), cost


def fair_clustering_large_cluster(dataset, config_file, data_dir, num_clusters, deltas, max_points, L=0, p_acc=1.0, ml_model_flag=False):
    """Run color-blind, GF, DS, doubly-GF and doubly-DS clustering on a dataset.

    Reads ``dataset`` as described by ``config_file`` and gives each point a
    single color from the ``fairness_variable`` bucket conditions. It then:

    1. runs ordinary (color-blind) clustering with the config's
       ``clustering_method``;
    2. solves the GF assignment on those centers with
       ``fair_partial_assignment_large_cluster`` (partial assignment followed
       by rounding to an integral assignment);
    3. picks DS centers with ``fairKcenterRange`` under per-color bounds on the
       number of centers, and assigns every point to its nearest DS center;
    4. converts the GF solution with ``GF_to_GFDS``;
    5. solves the GF assignment on the DS centers and converts it with
       ``DS_to_GFDS``.

    Costs, assignments, center colors and timings of every variant are
    written with ``write_fairness_trial`` and also returned.

    Args:
        dataset (str): section of the config file describing the dataset.
        config_file (str): path of the config file (read by ConfigParser).
        data_dir (str): output location passed to ``write_fairness_trial``.
        num_clusters (int): number of clusters/centers.
        deltas (list[float]): must contain exactly one value ``delta``. For
            each color with proportion ``p`` in the data, ``alpha = (1 + delta) * p``
            and ``beta = (1 - delta) * p``.
        max_points (int): if truthy and the dataset has more rows, only the
            first ``max_points`` rows are kept (no random subsampling).
        L (int): forwarded unchanged to ``fair_partial_assignment_large_cluster``.
        p_acc (float): currently unused.
        ml_model_flag (bool): currently unused.

    Returns:
        dict: the trial record that was written to ``data_dir``.

    Raises:
        ValueError: if ``deltas`` does not hold exactly one value, if the config
            lists more than one fairness variable, or if the DS lower bounds on
            the center colors sum to more than ``num_clusters``.
    """
    # Only one delta per call is supported (one trial record is produced);
    # check this before doing any expensive work.
    if len(deltas) != 1:
        raise ValueError('exactly one delta is supported, got %d' % len(deltas))
    delta = deltas[0]

    config = configparser.ConfigParser(converters={'list': read_list})
    config.read(config_file)

    # df (pd.DataFrame) : holds the data
    df = read_data(config, dataset)

    # Keep only the first `max_points` rows if requested.
    if max_points and len(df) > max_points:
        df = df.head(max_points)

    # Clean the data (bucketize text data)
    df, _ = clean_data(df, config, dataset)

    # NOTE: this code only handles one color per vertex.
    fairness_vars = config[dataset].getlist("fairness_variable")
    if len(fairness_vars) != 1:
        raise ValueError('exactly one fairness_variable is supported, got %d' % len(fairness_vars))
    variable = fairness_vars[0]

    bucket_conditions = config[dataset].getlist(variable + "_conditions")
    colors, color_flag = _assign_colors(df, variable, bucket_conditions)

    # Size the probability vectors by the largest color index present rather
    # than by the number of distinct colors present, so that a color missing
    # from the (sub)sampled points does not push a higher color index out of
    # range. (deterministic_prob_vecs presumably builds one column per color
    # index -- TODO confirm.)
    num_color_slots = max(colors) + 1 if colors else 0
    prob_vecs = deterministic_prob_vecs(len(df), num_color_slots, color_flag)

    # representation (dict[int -> float]) : share of the points in each color
    color_proportions = np.sum(prob_vecs, axis=0) / len(df)
    representation = {j: color_proportions[j] for j in range(color_proportions.shape[0])}
    num_colors = max(representation.keys()) + 1

    # Select only the desired columns
    df = df[list(config[dataset].getlist("columns"))]

    # Scale data if desired
    scaling = config["DEFAULT"].getboolean("scaling")
    if scaling:
        df = scale_data(df)

    clustering_method = config["DEFAULT"]["clustering_method"]

    # Step 1: ordinary (color-blind) clustering; clust_indices holds the
    # indices of the chosen centers.
    t1 = time.monotonic()
    initial_score, pred, cluster_centers, clust_indices = vanilla_clustering(df, num_clusters, clustering_method)
    cluster_time = time.monotonic() - t1

    # Per-color upper (alpha) and lower (beta) multipliers for the GF solver;
    # every color of the attribute uses the same delta.
    a_val, b_val = 1 + delta, 1 - delta
    alpha = {k: a_val * representation[k] for k in colors}
    beta = {k: b_val * representation[k] for k in colors}

    fp_color_flag = take_by_key({variable: color_flag}, fairness_vars)
    fp_alpha = take_by_key({variable: alpha}, fairness_vars)
    fp_beta = take_by_key({variable: beta}, fairness_vars)

    # Step 2: GF -- solve the partial assignment, then round to an integral one.
    t1 = time.monotonic()
    resGF = fair_partial_assignment_large_cluster(df, cluster_centers, fp_alpha, fp_beta, fp_color_flag, clustering_method, num_colors, L)
    GF_time = (time.monotonic() - t1) + cluster_time

    # Step 3: DS -- per-color bounds on the number of centers. Any color may
    # hold every center; each color needs at least 80% of its proportional
    # share of the centers (rounded up).
    X = df.to_numpy()
    classTable = np.asarray(color_flag)
    centerUpperBound = np.zeros(num_colors)
    centerLowerBound = np.zeros(num_colors)
    for col in range(num_colors):
        centerUpperBound[col] = int(num_clusters)
        centerLowerBound[col] = int(np.ceil(0.8 * representation[col] * num_clusters))

    # Written as `not (... <= ...)` so that a NaN sum is also rejected.
    if not np.sum(centerLowerBound) <= num_clusters:
        print(centerLowerBound)
        raise ValueError('DS Constraints are infeasible')

    # ds_center_indices holds the row indices of the DS centers.
    t1 = time.monotonic()
    ds_center_indices = fairKcenterRange(X, classTable, centerUpperBound, centerLowerBound, num_clusters, metric='euclidean')
    DS_time = time.monotonic() - t1

    # Sanity-check that the DS centers satisfy the color bounds.
    DS_center_colors = get_center_colors_no_closing(ds_center_indices, color_flag, num_colors)
    for col in range(num_colors):
        assert centerLowerBound[col] <= DS_center_colors[col] <= centerUpperBound[col]
    assert sum(DS_center_colors) == num_clusters

    x_assignment_ds, cost_DS = _nearest_center_assignment(X, ds_center_indices, num_clusters)

    # Step 4: GF --> GFDS
    t1 = time.monotonic()
    (double_GF_clust_indices, doubly_GF_assignment, doubly_GF_cost,
     dGF_num_clusters_active, GF_emptycluster_flag) = GF_to_GFDS(
        X, resGF["assignment"], color_flag, clust_indices, num_clusters,
        num_colors, centerUpperBound, centerLowerBound)
    doubly_GF_time = time.monotonic() - t1

    # Step 5: DS --> GFDS; the solver receives the DS centers as records.
    ds_centers = [df.iloc[index].tolist() for index in ds_center_indices]
    t1 = time.monotonic()
    res_doubly_DS = fair_partial_assignment_large_cluster(df, ds_centers, fp_alpha, fp_beta, fp_color_flag, clustering_method, num_colors, L)
    (double_DS_clust_indices, doubly_DS_assignment, doubly_DS_cost,
     dDS_num_clusters_active, DS_emptycluster_flag) = DS_to_GFDS(
        X, res_doubly_DS["assignment"], res_doubly_DS['objective'], color_flag,
        ds_center_indices, num_clusters, num_colors, centerUpperBound, centerLowerBound)
    doubly_DS_time = time.monotonic() - t1

    # Colors of the centers chosen by each variant.
    color_blind_center_colors = get_center_colors_no_closing(clust_indices, color_flag, num_colors)
    GF_center_colors = get_center_colors_possible_closing(resGF["assignment"], num_clusters, clust_indices, color_flag, num_colors)
    doubly_GF_colors = get_center_colors_no_closing(double_GF_clust_indices, color_flag, num_colors)
    doubly_DS_colors = get_center_colors_no_closing(double_DS_clust_indices, color_flag, num_colors)

    # Trial record written by `write_fairness_trial`; key order is kept stable
    # for consumers of the written file.
    output = {
        "num_clusters": num_clusters,
        "num_colors": num_colors,
        # Proportions, alphas and betas used for the trial
        'prob_proportions': representation,
        "alpha": alpha,
        "beta": beta,
        # Upper and lower bounds for the colors of the centers
        "DS_lowerBounds": list(centerLowerBound),
        "DS_upperBounds": list(centerUpperBound),
        # Colors of the points and their probability vectors
        'color_flag': color_flag,
        'prob_vecs': prob_vecs.ravel().tolist(),
        # Color-blind cost and prediction
        "unfair_cost": initial_score,
        "unfair_assignments": pred,
        'unfair_DS_violation': 0,
        # GF cost and prediction
        "GF_cost": resGF["objective"],
        "GF_assignment": resGF["assignment"],
        # NOTE(review): there is no separate violation entry for the DS
        # variant below -- confirm none is expected by consumers of the file.
        'GF_DS_violation': 0,
        # DS cost and prediction
        "DS_cost": cost_DS,
        "DS_assignment": x_assignment_ds,
        # Doubly GF cost and prediction
        "doublyGF_cost": doubly_GF_cost,
        "doublyGF_assignment": doubly_GF_assignment,
        'doublyGF_violation': 0,
        'dGF_num_clusters_active': dGF_num_clusters_active,
        'GF_emptycluster_flag': GF_emptycluster_flag,
        # Doubly DS cost and prediction
        "doublyDS_cost": doubly_DS_cost,
        "doublyDS_assignment": doubly_DS_assignment,
        'doublyDS_violation': 0,
        'dDS_num_clusters_active': dDS_num_clusters_active,
        'DS_emptycluster_flag': DS_emptycluster_flag,
        # Center colors
        "ColorBlind_center_colors": color_blind_center_colors,
        "GF_center_colors": GF_center_colors,
        "DS_center_colors": DS_center_colors,
        "doublyGF_center_colors": doubly_GF_colors,
        "doublyDS_center_colors": doubly_DS_colors,
        'num_points': len(df),
        "scaling": scaling,
        # Times
        "ColorBlind_time": cluster_time,
        "GF_time": GF_time,
        "DS_time": DS_time,
        "doubly_GF_time": doubly_GF_time,
        "doubly_DS_time": doubly_DS_time,
    }

    # Writes the data in `output` to a file in data_dir
    write_fairness_trial(output, data_dir)

    # Back-to-back trials can finish so fast that `write_fairness_trial` has
    # not yet written to disk; pause briefly before returning.
    time.sleep(1)

    return output
